// Copyright by LunaSec (owned by Refinery Labs, Inc)
//
// Licensed under the Business Source License v1.1
// (the "License"); you may not use this file except in compliance with the
// License. You may obtain a copy of the License at
//
// https://github.com/lunasec-io/lunasec/blob/master/licenses/BSL-LunaTrace.txt
//
// See the License for the specific language governing permissions and
// limitations under the License.
package scrape

import (
	"context"
	"crypto/tls"
	"database/sql"
	"encoding/json"
	"net/http"
	"net/url"
	"os"
	"path"
	"strings"
	"sync"
	"time"

	"github.com/go-jet/jet/v2/postgres"

	"github.com/PullRequestInc/go-gpt3"
	"github.com/go-shiori/go-readability"
	"github.com/google/uuid"
	"github.com/mozillazg/go-slugify"
	"github.com/rs/zerolog/log"
	"github.com/schollz/progressbar/v3"
	"go.uber.org/fx"

	"golang.org/x/net/html"

	"github.com/lunasec-io/lunasec/lunatrace/bsl/ingest-worker/pkg/util"
	"github.com/lunasec-io/lunasec/lunatrace/gogen/proto/gen"
	packschem "github.com/lunasec-io/lunasec/lunatrace/gogen/sqlgen/lunatrace/package/table"
	"github.com/lunasec-io/lunasec/lunatrace/gogen/sqlgen/lunatrace/vulnerability/model"
	"github.com/lunasec-io/lunasec/lunatrace/gogen/sqlgen/lunatrace/vulnerability/table"
)

// Module is the fx option set for this package. It registers the NewConfig
// and NewScraper constructors with the dependency injection container.
var Module = fx.Options(
	fx.Provide(
		NewConfig,
		NewScraper,
	),
)

// Scraper is the interface returned by NewScraper. It exposes the scraping
// and export operations of this package.
type Scraper interface {
	// ScrapeVulnerabilities loads vulnerability references from the database,
	// scrapes each reference URL and stores the result. A non-empty ecosystem
	// or vulnID narrows the query; onlyUnfetchedContent adds a filter on the
	// reference content's last-successful-fetch timestamp.
	ScrapeVulnerabilities(ecosystem, vulnID string, onlyUnfetchedContent bool) error
	// LoadAndOutputToDir reads processed references from the cache database
	// and writes one file per reference into outputDir, as markdown when
	// markdown is true and as JSON otherwise.
	LoadAndOutputToDir(cache string, outputDir string, markdown bool) error
	// ScrapeURLWithChrome scrapes a single URL. The implementation is not in
	// this part of the file; presumably it renders the page in a Chrome
	// browser -- confirm against the implementation.
	ScrapeURLWithChrome(url string) (*ScrapeResponse, error)
}

// scraperDeps is the fx parameter struct consumed by NewScraper.
type scraperDeps struct {
	fx.In
	// Config is embedded; this file reads its Workers and BrowserDomains
	// fields through the embedding.
	Config

	// DB is the database handle used to execute the jet statements.
	DB *sql.DB
	// OpenAIClient is injected here but not referenced in this part of the file.
	OpenAIClient gpt3.Client
	// LangChain is the client used for the CleanAdvisory and CleanSnippets calls.
	LangChain gen.LangChainClient
}

// scraper is the concrete Scraper implementation built by NewScraper.
type scraper struct {
	// deps holds the injected dependencies and configuration.
	deps scraperDeps
	// httpClient is the HTTP client constructed in NewScraper.
	httpClient *http.Client
	// browserDomains is the comma-separated BrowserDomains config value split
	// into individual entries.
	browserDomains []string
}

// ReferenceContentAndSnippets couples the reference content row for one
// reference with the code snippets extracted from that reference.
type ReferenceContentAndSnippets struct {
	// ReferenceContent is embedded so its fields are promoted onto this struct.
	model.ReferenceContent
	// snippets holds the code snippets that are inserted after the content row.
	snippets []model.CodeSnippet
}

// scrapeVulnerabilityReference fetches the content behind ref.URL and enriches
// it step by step: raw content, normalized content, LLM summary and cleaned
// content, and (for "text/html" responses) scored code snippets.
//
// The function never returns an error. A failure at any stage is logged and
// whatever has been populated so far is returned, so the caller always gets a
// non-nil result carrying at least the reference ID.
func (p *scraper) scrapeVulnerabilityReference(ref *ReferenceInfo) *ReferenceContentAndSnippets {
	result := &ReferenceContentAndSnippets{
		ReferenceContent: model.ReferenceContent{ReferenceID: ref.ID},
		snippets:         make([]model.CodeSnippet, 0),
	}

	scraped, err := p.scrapeContent(ref.URL)
	if err != nil {
		log.Error().Err(err).Str("url", ref.URL).Msg("failed to scrape Content")
		return result
	}

	result.LastSuccessfulFetch = util.Ptr(time.Now())
	result.Content = scraped.Content
	result.Title = scraped.Title
	result.ContentType = scraped.ContentType

	// TODO break this out into its own step so that normalization is done completely after scraping
	normalized, err := normalizeReferenceContent(scraped.Content)
	if err != nil {
		log.Error().
			Err(err).
			Str("url", ref.URL).
			Msg("failed to normalize Content")
		return result
	}
	result.NormalizedContent = normalized

	// Ask the python langchain LLM service for the summary and cleaned content.
	advisoryReq := &gen.CleanAdvisoryRequest{
		Content:     normalized,
		Description: ref.VulnDesc,
	}
	cleaned, err := p.deps.LangChain.CleanAdvisory(context.Background(), advisoryReq)
	if err != nil {
		log.Error().Err(err).Msg("Error getting advisory cleaning LLM result from python langchain service")
		return result
	}
	result.Summary = &cleaned.Summary
	result.ParsedContent = &cleaned.Content

	// Snippet extraction only applies to HTML documents.
	if scraped.ContentType != "text/html" {
		return result
	}

	candidates, err := extractCodeFromHTMLString(scraped.Content, ref.VulnDesc)
	if err != nil {
		log.Error().Err(err).Msg("Error parsing code snippets from HTML, probably bad html")
		return result
	}
	if len(candidates) == 0 {
		return result
	}

	snippetReq := &gen.CleanSnippetsRequest{Snippets: candidates}
	scored, err := p.deps.LangChain.CleanSnippets(context.Background(), snippetReq)
	if err != nil {
		log.Error().Err(err).Msg("Error getting snippet processing result from LLM python langchain service")
		return result
	}

	for _, candidate := range scored.Snippets {
		// Drop the snippets the LLM scored below the cutoff.
		if candidate.Score < 50 {
			continue
		}
		// ID and CreatedAt are left at their zero values.
		result.snippets = append(result.snippets, model.CodeSnippet{
			ReferenceID:   ref.ID,
			SourceURL:     ref.URL,
			Vulnerability: ref.VulnerabilityID,
			Code:          candidate.Code,
			Score:         candidate.Score,
			Summary:       candidate.Summary,
			Type:          candidate.Type,
			Language:      candidate.Language,
		})
	}

	return result
}

// extractCode walks the HTML tree rooted at n depth-first and appends one
// snippet to codeTexts for every <code> element whose first child is a text
// node longer than 30 bytes. The (possibly grown) slice is returned.
//
// A <code> element with no children, or whose first child is not a text node,
// is skipped: for non-text nodes Data holds a tag name or comment text rather
// than source code, and an empty element has no first child to read at all.
// Its descendants are still visited.
func extractCode(n *html.Node, vulnDesc string, codeTexts []*gen.Snippet) []*gen.Snippet {
	if n == nil {
		return codeTexts
	}

	if n.Type == html.ElementNode && n.Data == "code" {
		// TODO: make this code also find preambles, a reference implementation can be seen in ml/js/index.ts
		if text := n.FirstChild; text != nil && text.Type == html.TextNode && len(text.Data) > 30 {
			codeTexts = append(codeTexts, &gen.Snippet{Code: text.Data, VulnDescription: vulnDesc, Preamble: "No Preamble Available, do your best without one."})
		}
	}

	for c := n.FirstChild; c != nil; c = c.NextSibling {
		codeTexts = extractCode(c, vulnDesc, codeTexts)
	}
	return codeTexts
}

// extractCodeFromHTMLString parses htmlString as an HTML document and returns
// the snippets that extractCode collects from it, each tagged with vulnDesc.
// The returned slice is non-nil (possibly empty) on success; a parse failure
// is returned as the error with a nil slice.
func extractCodeFromHTMLString(htmlString string, vulnDesc string) ([]*gen.Snippet, error) {
	root, parseErr := html.Parse(strings.NewReader(htmlString))
	if parseErr != nil {
		return nil, parseErr
	}

	return extractCode(root, vulnDesc, make([]*gen.Snippet, 0)), nil
}

// processVulnerabilityWorker drains refScrapeChan, scrapes every reference it
// receives and forwards the result to saveRefChan. It returns once
// refScrapeChan has been closed and emptied.
//
// wg.Done is deferred so the WaitGroup is released on every exit path. The
// caller is expected to have called wg.Add(1) before starting the worker. The
// returned error is always nil.
func (p *scraper) processVulnerabilityWorker(
	wg *sync.WaitGroup,
	refScrapeChan <-chan *ReferenceInfo,
	saveRefChan chan<- *ReferenceContentAndSnippets,
) error {
	defer wg.Done()

	for ref := range refScrapeChan {
		saveRefChan <- p.scrapeVulnerabilityReference(ref)
	}
	return nil
}

// referenceContentAlreadyExists queries for reference content belonging to the
// reference with the given ID (a UUID in string form).
//
// It returns the UUID parse error when referenceID is malformed, otherwise
// whatever the query reports. The caller treats a nil error as "content
// exists"; presumably the query yields a no-rows error when nothing matches
// -- confirm against the jet query result mapping documentation.
func (p *scraper) referenceContentAlreadyExists(referenceID string) error {
	parsedID, parseErr := uuid.Parse(referenceID)
	if parseErr != nil {
		return parseErr
	}

	contentTbl := table.ReferenceContent
	// TODO: Why are we doing these joins when it seems that we could just select by the UUID directly on the reference_content directly?
	joined := contentTbl.LEFT_JOIN(
		table.Reference, table.Reference.ID.EQ(contentTbl.ReferenceID),
	)
	lookup := joined.SELECT(
		table.Reference.ID,
	).WHERE(
		table.Reference.ID.EQ(postgres.UUID(parsedID)),
	)

	var found model.ReferenceContent
	return lookup.Query(p.deps.DB, &found)
}

// updateOrCreateRefContentWorker consumes scraped references from saveRefChan
// until the channel is closed. For each one it upserts the reference content
// row (keyed on the reference ID) and then inserts the attached code snippets
// one at a time, ignoring conflicts.
//
// Database errors are logged and never returned. A failed content upsert skips
// that reference's snippets; a failed snippet insert only skips that snippet.
//
// NOTE(review): the worker registers itself on saveWg from inside its own
// body. When it is started with `go`, a concurrent saveWg.Wait() in the caller
// is not ordered after this Add and may return early. The Add should move to
// the caller, before the goroutine is started.
func (p *scraper) updateOrCreateRefContentWorker(
	saveWg *sync.WaitGroup,
	saveRefChan <-chan *ReferenceContentAndSnippets,
) {
	saveWg.Add(1)
	defer saveWg.Done()

	contentTbl := table.ReferenceContent
	snippetTbl := table.CodeSnippet

	for scraped := range saveRefChan {
		// Insert the content row, or overwrite every content column when a row
		// for this reference already exists.
		upsertContent := contentTbl.INSERT(
			contentTbl.ReferenceID,
			contentTbl.ContentType,
			contentTbl.Content,
			contentTbl.NormalizedContent,
			contentTbl.ParsedContent,
			contentTbl.Summary,
			contentTbl.Title,
			contentTbl.LastSuccessfulFetch,
		).MODEL(scraped).ON_CONFLICT(
			contentTbl.ReferenceID,
		).DO_UPDATE(
			postgres.SET(
				contentTbl.ContentType.SET(contentTbl.EXCLUDED.ContentType),
				contentTbl.Content.SET(contentTbl.EXCLUDED.Content),
				contentTbl.NormalizedContent.SET(contentTbl.EXCLUDED.NormalizedContent),
				contentTbl.Title.SET(contentTbl.EXCLUDED.Title),
				contentTbl.LastSuccessfulFetch.SET(contentTbl.EXCLUDED.LastSuccessfulFetch),
				contentTbl.ParsedContent.SET(contentTbl.EXCLUDED.ParsedContent),
				contentTbl.Summary.SET(contentTbl.EXCLUDED.Summary),
			),
		)

		if _, err := upsertContent.Exec(p.deps.DB); err != nil {
			log.Error().
				Err(err).
				Msg("failed to upsert reference Content")
			continue
		}

		// todo: inserting all the snippets at once wasn't working so this just loops through and does one at a time. More Jet knowledge would probably make it easy to do concurrently
		for _, snippet := range scraped.snippets {
			insertSnippet := snippetTbl.INSERT(
				snippetTbl.Vulnerability,
				snippetTbl.ReferenceID,
				snippetTbl.Code,
				snippetTbl.Summary,
				snippetTbl.Score,
				snippetTbl.SourceURL,
				snippetTbl.Language,
				snippetTbl.Type,
			).MODEL(
				snippet,
			).ON_CONFLICT(
				// todo: I don't know if this will match the unique index conflict properly, doesnt seem possible to do conflicts on named unique indexes in jet
				snippetTbl.Vulnerability, snippetTbl.Code,
			).DO_NOTHING()

			if _, err := insertSnippet.Exec(p.deps.DB); err != nil {
				log.Error().
					Err(err).
					Msg("failed to insert new snippets")
			}
		}
	}
}

// ScrapeVulnerabilities selects vulnerability references from the database,
// scrapes each one with a pool of p.deps.Workers goroutines and persists the
// results through a single save worker.
//
// Filters (all optional, combined with AND):
//   - ecosystem: when non-empty, only packages of this package manager.
//   - vulnID: when non-empty, only the vulnerability with this source ID.
//   - onlyUnfetchedContent: when true, only references whose content has no
//     successful fetch recorded (including references with no content row,
//     because the content table is LEFT JOINed).
//
// It returns an error only when the initial query fails. Scrape and save
// failures are logged by the workers. The call blocks until every reference
// has been scraped and every result has been handed to the database.
func (p *scraper) ScrapeVulnerabilities(ecosystem, vulnID string, onlyUnfetchedContent bool) error {
	// NOTE(review): the joins on affected/package can presumably yield the same
	// reference more than once when a vulnerability affects several packages --
	// confirm whether de-duplication is wanted.
	query := table.Vulnerability.SELECT(
		table.Vulnerability.ID.AS("VulnerabilityID"),
		table.Vulnerability.Details.AS("Details"),
		table.Vulnerability.Summary.AS("Summary"),
		table.Reference.ID.AS("ReferenceID"),
		table.Reference.URL.AS("ReferenceURL"),
	).FROM(
		table.Vulnerability.INNER_JOIN(
			table.Reference, table.Reference.VulnerabilityID.EQ(table.Vulnerability.ID),
		).INNER_JOIN(
			table.Affected, table.Affected.VulnerabilityID.EQ(table.Vulnerability.ID),
		).INNER_JOIN(
			packschem.Package, packschem.Package.ID.EQ(table.Affected.PackageID),
		).LEFT_JOIN(
			table.ReferenceContent, table.ReferenceContent.ReferenceID.EQ(table.Reference.ID),
		),
	).ORDER_BY(table.Vulnerability.SourceID.DESC())

	var whereClauses []postgres.BoolExpression
	if ecosystem != "" {
		whereClauses = append(whereClauses, packschem.Package.PackageManager.EQ(postgres.NewEnumValue(ecosystem)))
	}
	if vulnID != "" {
		whereClauses = append(whereClauses, table.Vulnerability.SourceID.EQ(postgres.String(vulnID)))
	}
	if onlyUnfetchedContent {
		// "Unfetched" means no successful fetch has been recorded, so the
		// timestamp must be NULL. Testing IS NOT NULL selects exactly the
		// references that were already fetched.
		whereClauses = append(whereClauses, table.ReferenceContent.LastSuccessfulFetch.IS_NULL())
	}

	if len(whereClauses) > 0 {
		log.Info().Any("vulnid", vulnID).Msg("building query filter")

		whereClause := whereClauses[0]
		for _, clause := range whereClauses[1:] {
			whereClause = whereClause.AND(clause)
		}
		query = query.WHERE(whereClause)
	}

	log.Info().Msg("filtering query")

	// Field names are matched against the column aliases above by the query
	// result mapper, so they must not be renamed independently of the aliases.
	var results []struct {
		//VulnerabilityInfo
		ReferenceID     uuid.UUID
		ReferenceURL    string
		Details         string
		Summary         string
		VulnerabilityId uuid.UUID
	}
	if err := query.Query(p.deps.DB, &results); err != nil {
		log.Error().Err(err).Msg("failed to get vulnerability rows")
		return err
	}

	log.Info().Int("number_of_references", len(results)).Msg("fetched references")

	bar := progressbar.Default(int64(len(results)))

	refScrapeChan := make(chan *ReferenceInfo, 100)
	saveRefChan := make(chan *ReferenceContentAndSnippets, 100)

	var (
		wg     sync.WaitGroup
		saveWg sync.WaitGroup
	)

	for i := 0; i < p.deps.Workers; i++ {
		wg.Add(1)
		go func() {
			err := p.processVulnerabilityWorker(&wg, refScrapeChan, saveRefChan)
			if err != nil {
				log.Error().Err(err).Msg("failed to process vulnerability worker")
			}
		}()
	}

	// The save worker registers itself on saveWg from inside its own goroutine,
	// which cannot be ordered against a Wait performed here: Wait could observe
	// a zero counter and return before the worker has even started. Track
	// completion with a dedicated channel instead. saveWg is still passed
	// because the worker's signature requires it.
	saveDone := make(chan struct{})
	go func() {
		defer close(saveDone)
		p.updateOrCreateRefContentWorker(&saveWg, saveRefChan)
	}()

	for _, vulnInfo := range results {
		bar.Add(1)

		// A nil error means content for this reference already exists.
		// NOTE(review): in that case a placeholder built from the
		// vulnerability's own summary/details is upserted before the reference
		// is scraped again -- confirm that overwriting stored content this way
		// is intended.
		if err := p.referenceContentAlreadyExists(vulnInfo.ReferenceID.String()); err == nil {
			saveRefChan <- &ReferenceContentAndSnippets{
				ReferenceContent: model.ReferenceContent{
					ReferenceID: vulnInfo.ReferenceID,
					Title:       vulnInfo.Summary,
					Content:     vulnInfo.Details,
				},
			}
		}

		// TODO check the error to make sure it is just a "not found" error
		// "operator does not exist: uuid = text at character"

		// send this reference to be scraped
		refScrapeChan <- &ReferenceInfo{
			Reference: model.Reference{
				ID:              vulnInfo.ReferenceID,
				VulnerabilityID: vulnInfo.VulnerabilityId,
				URL:             vulnInfo.ReferenceURL,
			},
			VulnDesc: vulnInfo.Details,
		}
	}

	// No more work: let the scrape workers drain and exit, then close the save
	// channel (they are its only other senders) and wait for the save worker.
	close(refScrapeChan)
	log.Info().Msg("waiting for workers to finish")
	wg.Wait()

	close(saveRefChan)
	log.Info().Msg("waiting for references to finish saving")
	<-saveDone

	return nil
}

// LoadAndOutputToDir reads every row of the "processed_references" table from
// the cache database and writes one file per reference into outputDir, which
// is created if needed.
//
// With markdown set, each file is a ".md" document produced by
// formatContentAsMarkdown; if that conversion fails, a minimal markdown
// document is built from the readability text instead. Otherwise each file is
// a ".json" serialization of the reference with its title and content replaced
// by the readability extraction. File names are the slugified reference URL.
//
// Per-reference failures (scan, URL parse, HTML parse, serialize, write) are
// logged and the reference is skipped. The function returns an error when the
// cache cannot be opened, the output directory cannot be created, the rows
// cannot be queried, or row iteration itself ends with an error.
func (p *scraper) LoadAndOutputToDir(cache string, outputDir string, markdown bool) error {
	db, err := loadGormDB(cache)
	if err != nil {
		return err
	}

	if err = os.MkdirAll(outputDir, 0755); err != nil {
		return err
	}

	rows, err := db.Table("processed_references").Rows()
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		var ref ProcessedReference
		if err := db.ScanRows(rows, &ref); err != nil {
			log.Error().Err(err).Msg("failed to scan reference")
			continue
		}

		parsedURL, err := url.Parse(ref.URL)
		if err != nil {
			// Log the skip so it is visible, like every other per-reference failure.
			log.Error().Err(err).Str("url", ref.URL).Msg("failed to parse reference url")
			continue
		}

		article, err := readability.FromReader(strings.NewReader(ref.Content), parsedURL)
		if err != nil {
			log.Error().Err(err).Msg("failed to parse html body")
			continue
		}

		var (
			content []byte
			ext     string
		)

		if markdown {
			ext = ".md"

			strContent, err := formatContentAsMarkdown(ref.Content, ref.URL)
			if err != nil {
				log.Warn().Err(err).Msg("failed to convert reference to markdown")
				content = []byte("# " + ref.Title + "\n\n" + "## Vulnerability" + "\n[[" + ref.VulnerabilityID + "]]\n\n" + article.TextContent)
			} else {
				content = []byte(strContent)
			}
		} else {
			ext = ".json"

			ref.Title = article.Title
			ref.Content = article.TextContent

			content, err = json.Marshal(ref)
			if err != nil {
				log.Error().Err(err).Msg("failed to serialize reference")
				continue
			}
		}

		outFile := path.Join(outputDir, slugify.Slugify(ref.URL)+ext)
		if err := os.WriteFile(outFile, content, 0644); err != nil {
			log.Error().Err(err).Msg("failed to write reference")
			continue
		}
	}

	// rows.Next returns false both at the end of the result set and on an
	// iteration error; only rows.Err tells the two apart.
	return rows.Err()
}

// NewScraper builds the Scraper implementation from its injected dependencies.
//
// The HTTP client it creates has a 5 second timeout and does not verify TLS
// certificates. The browser domain list is the BrowserDomains config value
// with all spaces removed, split on commas.
//
// NOTE(review): certificate verification is disabled for every request made
// through this client -- confirm this is deliberate.
// NOTE(review): strings.Split on an empty string yields a single empty entry,
// so an unset BrowserDomains produces [""] rather than an empty list -- verify
// that the consumers of browserDomains tolerate an empty entry.
func NewScraper(deps scraperDeps) Scraper {
	domainList := strings.ReplaceAll(deps.BrowserDomains, " ", "")

	return &scraper{
		deps: deps,
		httpClient: &http.Client{
			Timeout: time.Second * 5,
			Transport: &http.Transport{
				TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
			},
		},
		browserDomains: strings.Split(domainList, ","),
	}
}
